{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torch.optim as optim\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "chars = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']\n",
    "char_dic = {w:i for i,w in enumerate(chars)}\n",
    "pairs = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]\n",
    "\n",
    "#parameters\n",
    "n_step = 5\n",
    "n_hidden = 128\n",
    "n_class = len(char_dic)\n",
    "batch_size = len(pairs)\n",
    "\n",
    "\n",
    "def make_batch(pairs):\n",
    "    input_batch = []\n",
    "    output_batch = []\n",
    "    target_batch = []\n",
    "    for pair in pairs:\n",
    "        for i in range(2):\n",
    "            pair[i] = pair[i]+'P'*(n_step-len(pair[i]))\n",
    "        input = [char_dic[n] for n in pair[0]]\n",
    "        output = [char_dic[n] for n in ('S'+pair[1])]\n",
    "        target = [char_dic[n] for n in (pair[1]+'E')]\n",
    "        input_batch.append(np.eye(n_class)[input])\n",
    "        output_batch.append(np.eye(n_class)[output])\n",
    "        target_batch.append(target)\n",
    "        \n",
    "    return Variable(torch.Tensor(input_batch)),Variable(torch.Tensor(output_batch)),Variable(torch.Tensor(target_batch).type(torch.LongTensor))\n",
    "\n",
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Seq2Seq,self).__init__()\n",
    "        self.enc_cell = nn.RNN(input_size=n_class,hidden_size=n_hidden,dropout=0.5)\n",
    "        self.dec_cell = nn.RNN(input_size=n_class,hidden_size=n_hidden,dropout=0.5)\n",
    "        self.fc = nn.Linear(n_hidden,n_class)\n",
    "    \n",
    "    def forward(self,enc_input,enc_hidden,dec_input):\n",
    "        enc_input = enc_input.transpose(0,1)#[batch_size,n_step,hidden_size]->[n_step,batch_size,hidden_size]\n",
    "        dec_input = dec_input.transpose(0,1)#[n_step,batch_size,hidden_size]\n",
    "        \n",
    "        #enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
    "        _,enc_states = self.enc_cell(enc_input,enc_hidden)\n",
    "        #dec_output : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]\n",
    "        dec_output,_ = self.dec_cell(dec_input,enc_states)\n",
    "        \n",
    "        model = self.fc(dec_output)\n",
    "        return model  # model : [max_len+1(=6), batch_size, n_class]\n",
    "    \n",
    "def translate(word):\n",
    "    input_batch,output_batch,_ = make_batch([[word,'P'*len(word)]])\n",
    "    hidden = Variable(torch.zeros(1,1,n_hidden))\n",
    "    output = model(input_batch,hidden,output_batch)\n",
    "    predict = output.data.max(2,keepdim=True)[1]\n",
    "    decoded = [chars[i] for i in predict]\n",
    "    end = decoded.index('E')\n",
    "    translated = ''.join(decoded[:end])\n",
    "    return translated.replace('P', '')\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/caijie/anaconda3/lib/python3.6/site-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.5 and num_layers=1\n",
      "  \"num_layers={}\".format(dropout, num_layers))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:100,loss:0.010268\n",
      "epoch:200,loss:0.003431\n",
      "epoch:300,loss:0.001938\n",
      "epoch:400,loss:0.001276\n",
      "epoch:500,loss:0.000915\n",
      "epoch:600,loss:0.000694\n",
      "epoch:700,loss:0.000547\n",
      "epoch:800,loss:0.000443\n",
      "epoch:900,loss:0.000367\n",
      "epoch:1000,loss:0.000309\n",
      "epoch:1100,loss:0.000265\n",
      "epoch:1200,loss:0.000228\n",
      "epoch:1300,loss:0.000199\n",
      "epoch:1400,loss:0.000174\n",
      "epoch:1500,loss:0.000154\n",
      "epoch:1600,loss:0.000138\n",
      "epoch:1700,loss:0.000125\n",
      "epoch:1800,loss:0.000111\n",
      "epoch:1900,loss:0.000100\n",
      "epoch:2000,loss:0.000091\n",
      "epoch:2100,loss:0.000083\n",
      "epoch:2200,loss:0.000076\n",
      "epoch:2300,loss:0.000070\n",
      "epoch:2400,loss:0.000063\n",
      "epoch:2500,loss:0.000058\n",
      "epoch:2600,loss:0.000054\n",
      "epoch:2700,loss:0.000050\n",
      "epoch:2800,loss:0.000047\n",
      "epoch:2900,loss:0.000044\n",
      "epoch:3000,loss:0.000038\n",
      "epoch:3100,loss:0.000036\n",
      "epoch:3200,loss:0.000034\n",
      "epoch:3300,loss:0.000032\n",
      "epoch:3400,loss:0.000030\n",
      "epoch:3500,loss:0.000028\n",
      "epoch:3600,loss:0.000025\n",
      "epoch:3700,loss:0.000024\n",
      "epoch:3800,loss:0.000023\n",
      "epoch:3900,loss:0.000021\n",
      "epoch:4000,loss:0.000020\n",
      "epoch:4100,loss:0.000019\n",
      "epoch:4200,loss:0.000017\n",
      "epoch:4300,loss:0.000016\n",
      "epoch:4400,loss:0.000015\n",
      "epoch:4500,loss:0.000014\n",
      "epoch:4600,loss:0.000013\n",
      "epoch:4700,loss:0.000012\n",
      "epoch:4800,loss:0.000011\n",
      "epoch:4900,loss:0.000010\n",
      "epoch:5000,loss:0.000010\n"
     ]
    }
   ],
   "source": [
    "model = Seq2Seq()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(),lr = 0.01)\n",
    "\n",
    "input_batch,output_batch,target_batch = make_batch(pairs)\n",
    "\n",
    "for epoch in range(5000):\n",
    "    optimizer.zero_grad()\n",
    "    enc_hidden = Variable(torch.zeros(1,batch_size,n_hidden))\n",
    "    output = model(input_batch,enc_hidden,output_batch)\n",
    "    output = output.transpose(0,1)\n",
    "    loss = 0\n",
    "    for i in range(len(target_batch)):\n",
    "        loss += criterion(output[i],target_batch[i])\n",
    "    if (epoch+1)%1000 == 0:\n",
    "        print('epoch:%d,loss:%f' % (epoch+1,loss))\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test\n",
      "man -> women\n",
      "mans -> women\n",
      "king -> queen\n",
      "black -> white\n",
      "upp -> down\n"
     ]
    }
   ],
   "source": [
    "print('test')\n",
    "print('man ->', translate('man'))\n",
    "print('mans ->', translate('mans'))\n",
    "print('king ->', translate('king'))\n",
    "print('black ->', translate('black'))\n",
    "print('upp ->', translate('upp'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
